// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package internal

import (
	"bufio"
	"bytes"
	"compress/zlib"
	"io"
	"testing"

	"github.com/klauspost/compress/zstd"
	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/server/internal/testutil"
	"github.com/pingcap/tidb/pkg/server/internal/util"
	"github.com/stretchr/testify/require"
)

// BenchmarkPacketIOWrite measures PacketIO.WritePacket for a small, fixed
// payload written through a bufio.Writer that is backed by an in-memory buffer.
//
// A fresh buffer, PacketIO and payload slice are created on every iteration, so
// their construction is part of the measured work.
func BenchmarkPacketIOWrite(b *testing.B) {
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		var outBuffer bytes.Buffer
		pkt := &PacketIO{bufWriter: bufio.NewWriter(&outBuffer)}
		err := pkt.WritePacket([]byte{0x6d, 0x44, 0x42, 0x3a, 0x35, 0x36, 0x0, 0x0, 0x0, 0xfc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x68, 0x54, 0x49, 0x44, 0x3a, 0x31, 0x30, 0x38, 0x0, 0xfe})
		// A failing write would make the timings meaningless, so stop instead
		// of silently benchmarking the error path.
		if err != nil {
			b.Fatalf("WritePacket failed: %v", err)
		}
	}
}

// TestPacketIOWrite checks the bytes PacketIO emits for a payload that fits in
// a single packet and for a payload large enough to need more than one packet.
func TestPacketIOWrite(t *testing.T) {
	// Single packet: the four leading bytes of the input are replaced by the
	// packet header in the output, the remaining bytes are passed through.
	var sink bytes.Buffer
	writer := &PacketIO{bufWriter: bufio.NewWriter(&sink)}
	require.NoError(t, writer.WritePacket([]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}))
	require.NoError(t, writer.Flush())
	require.Equal(t, []byte{0x03, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, sink.Bytes())

	// More than one packet: only the header of the first packet is inspected.
	sink.Reset()
	writer = &PacketIO{bufWriter: bufio.NewWriter(&sink)}
	oversized := make([]byte, mysql.MaxPayloadLen+4)
	require.NoError(t, writer.WritePacket(oversized))
	require.NoError(t, writer.Flush())

	written := sink.Bytes()
	// Expected first header: length bytes ff ff ff followed by a zero byte.
	for idx, want := range []byte{0xff, 0xff, 0xff, 0x00} {
		require.Equal(t, want, written[idx])
	}
}

// TestPacketIOWriteCompressed writes a 16 MiB payload through a PacketIO that
// is configured for zlib compression and inspects the first seven bytes that
// reach the compressed writer's destination.
func TestPacketIOWriteCompressed(t *testing.T) {
	var compressedOut, plainOut bytes.Buffer
	seqNo := uint8(0)

	pktIO := &PacketIO{
		bufWriter:            bufio.NewWriter(&plainOut),
		compressionAlgorithm: mysql.CompressionZlib,
		compressedWriter:     newCompressedWriter(&compressedOut, mysql.CompressionZlib, &seqNo),
	}

	require.NoError(t, pktIO.WritePacket(bytes.Repeat([]byte{'A'}, 16*1024*1024)))
	require.NoError(t, pktIO.Flush())

	wire := compressedOut.Bytes()

	// Compressed length: 0x000418 = 1048 bytes.
	require.Equal(t, []byte{0x18, 0x4, 0x0}, wire[:3])
	// Packet number of the first compressed packet.
	require.Equal(t, []byte{0x0}, wire[3:4])
	// Uncompressed length: 0x100000 = 1048576 bytes.
	require.Equal(t, []byte{0x0, 0x0, 0x10}, wire[4:7])
}

// TestPacketIORead feeds hand-crafted wire bytes into a PacketIO and checks
// the payload returned by ReadPacket as well as the resulting pkt.sequence.
//
// Subtests:
//   - uncompressed: a one-byte packet, then a payload split over two packets.
//   - compressed_short: compression enabled, but the compressed header carries
//     an uncompressed length of 0, so the payload is stored as-is.
//   - zlib / zstd: compression enabled with an actually compressed payload.
func TestPacketIORead(t *testing.T) {
	t.Run("uncompressed", func(t *testing.T) {
		var inBuffer bytes.Buffer
		// Header 01 00 00 (length 1), 00 (sequence 0), followed by one payload byte.
		_, err := inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01})
		require.NoError(t, err)
		// Test read one packet
		brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
		pkt := NewPacketIO(brc)
		readBytes, err := pkt.ReadPacket()
		require.NoError(t, err)
		require.Equal(t, uint8(1), pkt.sequence)
		require.Equal(t, []byte{0x01}, readBytes)

		inBuffer.Reset()
		// Two packets back to back. The layout below is consistent with
		// mysql.MaxPayloadLen being 0xffffff:
		//   [0..3]                 first header: length ff ff ff, sequence 0
		//   [4..3+MaxPayloadLen]   first payload (all zero)
		//   [4+MaxPayloadLen..]    second header: length 01 00 00, sequence 1
		//   [8+MaxPayloadLen]      second payload: the single byte 0x0a
		buf := make([]byte, mysql.MaxPayloadLen+9)
		buf[0] = 0xff
		buf[1] = 0xff
		buf[2] = 0xff
		buf[3] = 0
		// The next two stores touch bytes that make() already zeroed.
		buf[2+mysql.MaxPayloadLen] = 0x00
		buf[3+mysql.MaxPayloadLen] = 0x00
		buf[4+mysql.MaxPayloadLen] = 0x01
		buf[7+mysql.MaxPayloadLen] = 0x01
		buf[8+mysql.MaxPayloadLen] = 0x0a

		_, err = inBuffer.Write(buf)
		require.NoError(t, err)
		// Test read multiple packets
		brc = util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
		pkt = NewPacketIO(brc)
		readBytes, err = pkt.ReadPacket()
		require.NoError(t, err)
		// Both packets are consumed by a single ReadPacket call and their
		// payloads are returned as one slice.
		require.Equal(t, uint8(2), pkt.sequence)
		require.Equal(t, mysql.MaxPayloadLen+1, len(readBytes))
		require.Equal(t, byte(0x0a), readBytes[mysql.MaxPayloadLen])
	})
	t.Run("compressed_short", func(t *testing.T) {
		var inBuffer bytes.Buffer
		// Compressed header: 27 00 00 (length 39), 00 (packet nr),
		// 00 00 00 (uncompressed length 0: the payload is not compressed).
		// Regular packet header: 23 00 00 (length 35), 00 (sequence), then
		// the 35-byte payload.
		_, err := inBuffer.Write([]byte{0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
			0x23, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x73, 0x65, 0x6c, 0x65,
			0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f,
			0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c,
			0x69, 0x6d, 0x69, 0x74, 0x20, 0x31})

		require.NoError(t, err)
		// Test read one packet
		brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
		pkt := NewPacketIO(brc)
		pkt.SetCompressionAlgorithm(mysql.CompressionZlib)
		readBytes, err := pkt.ReadPacket()
		require.NoError(t, err)
		require.Equal(t, uint8(1), pkt.sequence)

		// 03 00 01 73 65 6c 65 63  74 20 40 40 76 65 72 73  |...select @@vers|
		// 69 6f 6e 5f 63 6f 6d 6d  65 6e 74 20 6c 69 6d 69  |ion_comment limi|
		// 74 20 31                                          |t 1|
		expected := []byte{0x3, 0x0, 0x1, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20,
			0x40, 0x40, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63,
			0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69,
			0x74, 0x20, 0x31}

		require.Equal(t, expected, readBytes)
	})
	t.Run("zlib", func(t *testing.T) {
		var inBuffer bytes.Buffer

		// Header: 7 bytes (compressed_length<3>, packetnr<1>, uncompressed_length<3>)
		// Here: compressed length 0x39 (57), packet nr 0, uncompressed length 0x49 (73).
		// Payload: Starting with the magic nr for zlib: 0x78
		_, err := inBuffer.Write([]byte{0x39, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00,
			0x78, 0x5e, 0x73, 0x65, 0x60, 0x60, 0x60, 0x0e, 0x76, 0xf5, 0x71,
			0x75, 0x0e, 0x51, 0x30, 0x54, 0xd0, 0xd7, 0x52, 0x48, 0x4c, 0x4a,
			0x4e, 0x49, 0x4d, 0x4b, 0xcf, 0xc8, 0xcc, 0xca, 0xce, 0xc9, 0xcd,
			0xcb, 0x2f, 0x28, 0x2c, 0x2a, 0x2e, 0x29, 0x2d, 0x2b, 0xaf, 0xa8,
			0xac, 0x8a, 0xc7, 0x2d, 0xa5, 0xa0, 0xa5, 0x0f, 0x00, 0x59, 0xd8,
			0x1a, 0x09})

		require.NoError(t, err)
		// Test read one packet
		brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
		pkt := NewPacketIO(brc)
		pkt.SetCompressionAlgorithm(mysql.CompressionZlib)
		readBytes, err := pkt.ReadPacket()
		require.NoError(t, err)
		require.Equal(t, uint8(1), pkt.sequence)

		// 03 53 45 4c 45 43 54 20  31 20 2f 2a 20 61 62 63  |.SELECT 1 /* abc|
		// 64 65 66 67 68 69 6a 6b  6c 6d 6e 6f 70 71 72 73  |defghijklmnopqrs|
		// 74 75 76 77 78 79 7a 5f  61 62 63 64 65 66 67 68  |tuvwxyz_abcdefgh|
		// 69 6a 6b 6c 6d 6e 6f 70  71 72 73 74 75 76 77 78  |ijklmnopqrstuvwx|
		// 79 7a 20 2a 2f                                    |yz */|

		expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x31, 0x20,
			0x2f, 0x2a, 0x20, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68,
			0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73,
			0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x5f, 0x61, 0x62, 0x63,
			0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
			0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
			0x7a, 0x20, 0x2a, 0x2f}

		require.Equal(t, expected, readBytes)
	})
	t.Run("zstd", func(t *testing.T) {
		var inBuffer bytes.Buffer
		// Header: 7 bytes (compressed_length<3>, packetnr<1>, uncompressed_length<3>)
		// Here: compressed length 0x40 (64), packet nr 0, uncompressed length 0x49 (73).
		// Payload: Starting with the magic nr for zstd: 0x28, 0xb5, 0x2f, 0xfd
		_, err := inBuffer.Write([]byte{
			0x40, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00, 0x28, 0xb5, 0x2f, 0xfd, 0x20,
			0x49, 0xbd, 0x01, 0x00, 0xf4, 0x02, 0x45, 0x00, 0x00, 0x00, 0x03, 0x53,
			0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x31, 0x20, 0x2f, 0x2a, 0x20, 0x61,
			0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d,
			0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
			0x7a, 0x5f, 0x20, 0x2a, 0x2f, 0x01, 0x00, 0x74, 0x7b, 0x96, 0x01})

		require.NoError(t, err)
		// Test read one packet
		brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
		pkt := NewPacketIO(brc)
		pkt.SetCompressionAlgorithm(mysql.CompressionZstd)
		readBytes, err := pkt.ReadPacket()
		require.NoError(t, err)
		require.Equal(t, uint8(1), pkt.sequence)

		// 03 53 45 4c 45 43 54 20  31 20 2f 2a 20 61 62 63  |.SELECT 1 /* abc|
		// 64 65 66 67 68 69 6a 6b  6c 6d 6e 6f 70 71 72 73  |defghijklmnopqrs|
		// 74 75 76 77 78 79 7a 5f  61 62 63 64 65 66 67 68  |tuvwxyz_abcdefgh|
		// 69 6a 6b 6c 6d 6e 6f 70  71 72 73 74 75 76 77 78  |ijklmnopqrstuvwx|
		// 79 7a 20 2a 2f                                    |yz */|

		expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x31, 0x20,
			0x2f, 0x2a, 0x20, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68,
			0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73,
			0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x5f, 0x61, 0x62, 0x63,
			0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
			0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
			0x7a, 0x20, 0x2a, 0x2f}

		require.Equal(t, expected, readBytes)
	})
}

// TestCompressedWriterShort checks the output of the compressed writer for a
// payload that is too small to be worth compressing.
//
// Small payloads of less than minCompressLength (50 bytes) don't get actually
// compressed, but have the same header as a compressed packet.
//
// Header:
// 0a 00 00   Compressed length: 10
// 00         Packetnr: 0
// 00 00 00   Uncompressed length: 0 (meaning not compressed)
//
// Payload:
// 74 65 73 74 5f 73 68 6f 72 74  test_short
func TestCompressedWriterShort(t *testing.T) {
	var testdata bytes.Buffer
	payload := []byte("test_short")
	seq := uint8(0)

	cw := newCompressedWriter(&testdata, mysql.CompressionZlib, &seq)
	// Surface write/flush failures directly instead of letting them show up
	// later as confusing header mismatches.
	_, err := cw.Write(payload)
	require.NoError(t, err)
	require.NoError(t, cw.Flush())

	// The output must be exactly the 7-byte header plus the raw payload.
	// Checking the length first turns a short write into a test failure
	// rather than a slice-bounds panic.
	out := testdata.Bytes()
	require.Len(t, out, 7+len(payload))

	// Test Header
	compressedLength := []byte{0xa, 0x0, 0x0}
	packetNr := []byte{0x0}
	uncompressedLength := []byte{0x0, 0x0, 0x0}
	require.Equal(t, compressedLength, out[:3])
	require.Equal(t, packetNr, out[3:4])
	require.Equal(t, uncompressedLength, out[4:7])

	// Payload, making sure it isn't compressed.
	require.Equal(t, payload, out[7:])
}

// TestCompressedWriterLong checks the compressed writer with a 69-byte payload,
// once per compression algorithm. Each subtest verifies the 7-byte
// compressed-protocol header and then decompresses the remaining bytes with the
// matching library to confirm they round-trip to the original payload.
func TestCompressedWriterLong(t *testing.T) {
	t.Run("zlib", func(t *testing.T) {
		var testdata, decoded bytes.Buffer
		payload := []byte("test_zlib test_zlib test_zlib test_zlib test_zlib test_zlib test_zlib")
		seq := uint8(0)

		cw := newCompressedWriter(&testdata, mysql.CompressionZlib, &seq)
		_, err := cw.Write(payload)
		require.NoError(t, err)
		require.NoError(t, cw.Flush())

		// Fail cleanly (instead of panicking on the slices below) if less
		// than a full header was written.
		out := testdata.Bytes()
		require.GreaterOrEqual(t, len(out), 7)

		// Header:
		// 18 00 00   Compressed length: 24
		// 00         Packetnr: 0
		// 45 00 00   Uncompressed length: 69
		compressedLength := []byte{0x18, 0x0, 0x0}
		packetNr := []byte{0x0}
		uncompressedLength := []byte{0x45, 0x0, 0x0}
		require.Equal(t, compressedLength, out[:3])
		require.Equal(t, packetNr, out[3:4])
		require.Equal(t, uncompressedLength, out[4:7])

		// Payload:
		r, err := zlib.NewReader(bytes.NewReader(out[7:]))
		require.NoError(t, err)
		// A decompression error must fail the test rather than leave a
		// truncated buffer to be compared against the payload.
		_, err = io.Copy(&decoded, r)
		require.NoError(t, err)
		require.NoError(t, r.Close())
		require.Equal(t, payload, decoded.Bytes())
	})

	t.Run("zstd", func(t *testing.T) {
		var testdata bytes.Buffer
		payload := []byte("test_zstd test_zstd test_zstd test_zstd test_zstd test_zstd test_zstd")
		seq := uint8(0)

		cw := newCompressedWriter(&testdata, mysql.CompressionZstd, &seq)
		_, err := cw.Write(payload)
		require.NoError(t, err)
		require.NoError(t, cw.Flush())

		// Fail cleanly (instead of panicking on the slices below) if less
		// than a full header was written.
		out := testdata.Bytes()
		require.GreaterOrEqual(t, len(out), 7)

		// Header:
		// 1e 00 00   Compressed length: 30
		// 00         Packetnr: 0
		// 45 00 00   Uncompressed length: 69
		compressedLength := []byte{0x1e, 0x0, 0x0}
		packetNr := []byte{0x0}
		uncompressedLength := []byte{0x45, 0x0, 0x0}
		require.Equal(t, compressedLength, out[:3])
		require.Equal(t, packetNr, out[3:4])
		require.Equal(t, uncompressedLength, out[4:7])

		// Payload:
		d, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(0))
		require.NoError(t, err)
		// Release the decoder's resources when the subtest finishes.
		defer d.Close()
		decoded, err := d.DecodeAll(out[7:], nil)
		require.NoError(t, err)
		require.Equal(t, payload, decoded)
	})
}

// TestCompressedReaderShort reads a compressed-protocol packet whose header
// carries an uncompressed length of 0, which means the bytes after the header
// are stored without compression.
func TestCompressedReaderShort(t *testing.T) {
	const query = "select @@version_comment limit 1"

	// Wire bytes: 7-byte compressed header (length 0x25 = 37, packet nr 0,
	// uncompressed length 0), then the regular 4-byte packet header
	// (length 0x21 = 33, sequence 0), then the 33-byte body 0x03 + query.
	wire := []byte{0x25, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x21, 0x0, 0x0, 0x0, 0x3}
	wire = append(wire, query...)

	seqNo := uint8(0)
	reader := newCompressedReader(bytes.NewReader(wire), mysql.CompressionZlib, &seqNo)

	// The first four bytes handed out by the reader are the regular packet
	// header, not the compressed header.
	pktHeader := make([]byte, 4)
	_, err := reader.Read(pktHeader)
	require.NoError(t, err)
	require.Equal(t, []byte{0x21, 0x0, 0x0, 0x0}, pktHeader)

	bodyLen := int(uint32(pktHeader[0]) | uint32(pktHeader[1])<<8 | uint32(pktHeader[2])<<16)
	require.Equal(t, 33, bodyLen)
	require.Equal(t, uint8(0), pktHeader[3])

	body := make([]byte, bodyLen)
	_, err = reader.Read(body)
	require.NoError(t, err)
	require.Equal(t, append([]byte{0x3}, query...), body)
}

// TestCompressedReaderLong reads an actually compressed packet through the
// compressed reader, once per compression algorithm. Both fixtures decompress
// to the same bytes: a 4-byte packet header (length 152, sequence 0) followed
// by a 152-byte body.
func TestCompressedReaderLong(t *testing.T) {
	// Expected body shared by both subtests: 0x03, `SELECT "`, 142 'A' bytes
	// and a closing double quote, 152 bytes in total.
	wantBody := append([]byte{0x3}, `SELECT "`...)
	wantBody = append(wantBody, bytes.Repeat([]byte{'A'}, 142)...)
	wantBody = append(wantBody, '"')

	t.Run("zlib", func(t *testing.T) {
		// Compressed header: length 0x19 (25), packet nr 0, uncompressed
		// length 0x9c (156); the zlib stream follows.
		wire := []byte{0x19, 0x0, 0x0, 0x0, 0x9c, 0x0, 0x0, 0x78, 0x5e, 0x9b,
			0xc1, 0xc0, 0xc0, 0xc0, 0x1c, 0xec, 0xea, 0xe3, 0xea, 0x1c, 0xa2,
			0xa0, 0xe4, 0x38, 0xa8, 0x80, 0x12, 0x0, 0xbe, 0xe6, 0x26, 0xce}
		seqNo := uint8(0)
		reader := newCompressedReader(bytes.NewReader(wire), mysql.CompressionZlib, &seqNo)

		pktHeader := make([]byte, 4)
		_, err := reader.Read(pktHeader)
		require.NoError(t, err)
		require.Equal(t, []byte{0x98, 0x0, 0x0, 0x0}, pktHeader)

		bodyLen := int(uint32(pktHeader[0]) | uint32(pktHeader[1])<<8 | uint32(pktHeader[2])<<16)
		require.Equal(t, 152, bodyLen)
		require.Equal(t, uint8(0), pktHeader[3])

		body := make([]byte, bodyLen)
		_, err = reader.Read(body)
		require.NoError(t, err)
		require.Equal(t, wantBody, body)
	})
	t.Run("zstd", func(t *testing.T) {
		// Compressed header: length 0x1f (31), packet nr 0, uncompressed
		// length 0x9c (156); the zstd frame follows.
		wire := []byte{0x1f, 0x0, 0x0, 0x0, 0x9c, 0x0, 0x0, 0x28, 0xb5, 0x2f, 0xfd,
			0x20, 0x9c, 0xb5, 0x0, 0x0, 0x78, 0x98, 0x0, 0x0, 0x0, 0x3, 0x53,
			0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, 0x41, 0x22, 0x1, 0x0, 0xa,
			0xa, 0x28, 0x1}
		seqNo := uint8(0)
		reader := newCompressedReader(bytes.NewReader(wire), mysql.CompressionZstd, &seqNo)

		pktHeader := make([]byte, 4)
		_, err := reader.Read(pktHeader)
		require.NoError(t, err)
		require.Equal(t, []byte{0x98, 0x0, 0x0, 0x0}, pktHeader)

		bodyLen := int(uint32(pktHeader[0]) | uint32(pktHeader[1])<<8 | uint32(pktHeader[2])<<16)
		require.Equal(t, 152, bodyLen)
		require.Equal(t, uint8(0), pktHeader[3])

		body := make([]byte, bodyLen)
		_, err = reader.Read(body)
		require.NoError(t, err)
		require.Equal(t, wantBody, body)
	})
}

// TestSubHeaderWithWrongSequenceNumber covers a client quirk: MariaDB
// Connector/J 2.x puts a wrong sequence number into the regular packet header
// that is nested inside a compressed packet. TiDB should be compatible with it.
//
// MySQL Compressed Protocol Header:
// 0e 00 00                      Compressed length
// 00                            Compressed Packetnr
// 00 00 00                      Uncompressed length
//
// MySQL Protocol Header:
// 0a 00 00                      Payload length
// 01                            Packet Sequence Number (it should be 0x00, but MariaDB Connector/J 2.x sets it to 0x01)
// 03                            COM_QUERY
// 73 65 6c 65 63 74 20 31 3b    "select 1;"
func TestSubHeaderWithWrongSequenceNumber(t *testing.T) {
	const query = "select 1;"

	wire := []byte{
		// Compressed protocol header: length 14, packet nr 0, uncompressed length 0.
		0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
		// Regular protocol header: payload length 10, sequence 1 (the quirk).
		0x0a, 0x00, 0x00, 0x01,
		// COM_QUERY command byte; the query text is appended below.
		0x03,
	}
	wire = append(wire, query...)

	var inBuffer bytes.Buffer
	_, err := inBuffer.Write(wire)
	require.NoError(t, err)

	// Read the single packet with compression enabled.
	pkt := NewPacketIO(util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer}))
	pkt.SetCompressionAlgorithm(mysql.CompressionZlib)
	readBytes, err := pkt.ReadPacket()
	require.NoError(t, err)

	// Both counters end up at 1 and the payload is returned intact.
	require.Equal(t, uint8(1), pkt.sequence)
	require.Equal(t, uint8(1), pkt.compressedSequence)
	require.Equal(t, append([]byte{0x03}, query...), readBytes)
}
